package db

import (
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"os"
	"time"

	"github.com/cheggaaa/pb/v3"
	"github.com/glebarez/sqlite"
	"github.com/knqyf263/go-cpe/common"
	"github.com/knqyf263/go-cpe/naming"
	"github.com/spf13/viper"
	"github.com/vulsio/go-cve-dictionary/config"
	"github.com/vulsio/go-cve-dictionary/fetcher/jvn"
	"github.com/vulsio/go-cve-dictionary/fetcher/nvd"
	logger "github.com/vulsio/go-cve-dictionary/log"
	"github.com/vulsio/go-cve-dictionary/models"
	"golang.org/x/xerrors"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	gormLogger "gorm.io/gorm/logger"
)

// Supported DB dialects. RDBDriver.name is compared against these values to
// select the gorm dialector in OpenDB and the error handling in MigrateDB.
const (
	dialectSqlite3    = "sqlite3"  // SQLite, opened through github.com/glebarez/sqlite
	dialectMysql      = "mysql"    // MySQL, opened through gorm.io/driver/mysql
	dialectPostgreSQL = "postgres" // PostgreSQL, opened through gorm.io/driver/postgres
)

// RDBDriver is the Driver implementation for relational databases
// (SQLite, MySQL and PostgreSQL) accessed through gorm.
type RDBDriver struct {
	// name is the dialect name; OpenDB and MigrateDB switch on it.
	name string
	// conn is the gorm handle assigned by OpenDB. CloseDB treats a nil conn
	// as "nothing to close".
	conn *gorm.DB
}

// errNo is a numeric SQLite result code. The type mirrors the one defined in
// https://github.com/mattn/go-sqlite3/blob/edc3bb69551dcfff02651f083b21f3366ea2f5ab/error.go#L18-L66
type errNo int

// sqliteError is the decoding target used by OpenDB and MigrateDB: a driver
// error is marshalled to JSON and unmarshalled into this struct so that its
// Code field, if present, can be compared against errBusy / errLocked.
type sqliteError struct {
	Code errNo /* The error code returned by SQLite */
}

// result codes from http://www.sqlite.org/c3ref/c_abort.html
// Both codes are reported to callers as ErrDBLocked.
var (
	errBusy   = errNo(5) /* The database file is locked */
	errLocked = errNo(6) /* A table in the database is locked */
)

// ErrDBLocked is the sentinel error wrapped into the errors returned by
// OpenDB and MigrateDB when the SQLite error code is errBusy or errLocked.
var ErrDBLocked = xerrors.New("database is locked")

// Name returns the dialect name held by the driver (r.name), which is the
// same value OpenDB and MigrateDB switch on.
func (r *RDBDriver) Name() string {
	return r.name
}

// OpenDB opens a gorm connection to dbPath and stores it in r.conn.
//
// The dialect is chosen by r.name (not by dbType, which is only used in error
// messages). When debugSQL is true, statements are logged to stderr at Info
// level with a one second slow-query threshold; otherwise SQL logging is
// silent. The Option argument is ignored.
//
// For SQLite, an open failure whose error code is errBusy or errLocked is
// reported as a wrapped ErrDBLocked, and foreign key enforcement is switched
// on after a successful open. An unsupported r.name yields an error.
func (r *RDBDriver) OpenDB(dbType, dbPath string, debugSQL bool, _ Option) (err error) {
	// Only the logger configuration differs between debug and normal mode.
	logConfig := gormLogger.Config{
		LogLevel: gormLogger.Silent,
	}
	if debugSQL {
		logConfig = gormLogger.Config{
			SlowThreshold: time.Second,
			LogLevel:      gormLogger.Info,
			Colorful:      true,
		}
	}

	gormConfig := gorm.Config{
		DisableForeignKeyConstraintWhenMigrating: true,
		Logger: gormLogger.New(
			log.New(os.Stderr, "\r\n", log.LstdFlags),
			logConfig,
		),
	}

	switch r.name {
	case dialectSqlite3:
		r.conn, err = gorm.Open(sqlite.Open(dbPath), &gormConfig)
		if err != nil {
			// Round-trip the driver error through JSON to read its numeric
			// Code field without depending on the driver's concrete type.
			parsedErr, marshalErr := json.Marshal(err)
			if marshalErr != nil {
				return xerrors.Errorf("Failed to marshal err. err: %w", marshalErr)
			}

			var errMsg sqliteError
			if unmarshalErr := json.Unmarshal(parsedErr, &errMsg); unmarshalErr != nil {
				return xerrors.Errorf("Failed to unmarshal. err: %w", unmarshalErr)
			}

			switch errMsg.Code {
			case errBusy, errLocked:
				return xerrors.Errorf("Failed to open DB. dbtype: %s, dbpath: %s, err: %w", dbType, dbPath, ErrDBLocked)
			default:
				return xerrors.Errorf("Failed to open DB. dbtype: %s, dbpath: %s, err: %w", dbType, dbPath, err)
			}
		}

		// Report a failure to enable foreign keys instead of discarding it.
		if pragmaErr := r.conn.Exec("PRAGMA foreign_keys = ON").Error; pragmaErr != nil {
			return xerrors.Errorf("Failed to enable foreign keys. dbtype: %s, dbpath: %s, err: %w", dbType, dbPath, pragmaErr)
		}
	case dialectMysql:
		r.conn, err = gorm.Open(mysql.Open(dbPath), &gormConfig)
		if err != nil {
			return xerrors.Errorf("Failed to open DB. dbtype: %s, dbpath: %s, err: %w", dbType, dbPath, err)
		}
	case dialectPostgreSQL:
		r.conn, err = gorm.Open(postgres.Open(dbPath), &gormConfig)
		if err != nil {
			return xerrors.Errorf("Failed to open DB. dbtype: %s, dbpath: %s, err: %w", dbType, dbPath, err)
		}
	default:
		return xerrors.Errorf("Not Supported DB dialects. r.name: %s", r.name)
	}
	return nil
}

// CloseDB closes the underlying *sql.DB of the gorm connection.
// It is a no-op returning nil when no connection has been opened.
func (r *RDBDriver) CloseDB() (err error) {
	if r.conn == nil {
		return nil
	}

	var sqlDB *sql.DB
	sqlDB, err = r.conn.DB()
	if err != nil {
		return xerrors.Errorf("Failed to get DB Object. err : %w", err)
	}

	err = sqlDB.Close()
	if err != nil {
		return xerrors.Errorf("Failed to close DB. Type: %s. err: %w", r.name, err)
	}
	return nil
}

// MigrateDB runs gorm's AutoMigrate for the FetchMeta, NVD and JVN models.
//
// It returns nil on success. On failure the error is wrapped; for SQLite an
// error code of errBusy or errLocked is reported as a wrapped ErrDBLocked.
// An r.name that is not one of the supported dialects yields a
// "Not Supported DB dialects" error.
func (r *RDBDriver) MigrateDB() error {
	err := r.conn.AutoMigrate(
		&models.FetchMeta{},

		&models.Nvd{},
		&models.NvdDescription{},
		&models.NvdCvss2Extra{},
		&models.NvdCvss3{},
		&models.NvdCwe{},
		&models.NvdCpe{},
		&models.NvdEnvCpe{},
		&models.NvdReference{},
		&models.NvdCert{},

		&models.Jvn{},
		&models.JvnCvss2{},
		&models.JvnCvss3{},
		&models.JvnCpe{},
		&models.JvnReference{},
		&models.JvnCert{},
	)
	if err == nil {
		return nil
	}

	switch r.name {
	case dialectSqlite3:
		// Round-trip the driver error through JSON to read its numeric Code
		// field without depending on the driver's concrete type.
		parsedErr, marshalErr := json.Marshal(err)
		if marshalErr != nil {
			return xerrors.Errorf("Failed to marshal err. err: %w", marshalErr)
		}

		var errMsg sqliteError
		if unmarshalErr := json.Unmarshal(parsedErr, &errMsg); unmarshalErr != nil {
			return xerrors.Errorf("Failed to unmarshal. err: %w", unmarshalErr)
		}

		switch errMsg.Code {
		case errBusy, errLocked:
			return xerrors.Errorf("Failed to migrate. err: %w", ErrDBLocked)
		default:
			return xerrors.Errorf("Failed to migrate. err: %w", err)
		}
	case dialectMysql, dialectPostgreSQL:
		return xerrors.Errorf("Failed to migrate. err: %w", err)
	default:
		return xerrors.Errorf("Not Supported DB dialects. r.name: %s", r.name)
	}
}

// IsGoCVEDictModelV1 determines if the DB was created at the time of
// go-cve-dictionary Model v1.
//
// A database that already has the FetchMeta table is not v1. Otherwise the
// number of tables is counted through the dialect's catalog (sqlite_master,
// information_schema.tables or pg_tables): a non-empty database without
// FetchMeta is reported as v1, an empty one is not.
//
// An error is returned when the count query fails or when r.name is not a
// supported dialect.
func (r *RDBDriver) IsGoCVEDictModelV1() (bool, error) {
	if r.conn.Migrator().HasTable(&models.FetchMeta{}) {
		return false, nil
	}

	var (
		count int64
		err   error
	)
	switch r.name {
	case dialectSqlite3:
		err = r.conn.Table("sqlite_master").Where("type = ?", "table").Count(&count).Error
	case dialectMysql:
		err = r.conn.Table("information_schema.tables").Where("table_schema = ?", r.conn.Migrator().CurrentDatabase()).Count(&count).Error
	case dialectPostgreSQL:
		err = r.conn.Table("pg_tables").Where("schemaname = ?", "public").Count(&count).Error
	default:
		// Match OpenDB and MigrateDB rather than silently answering "not v1".
		return false, xerrors.Errorf("Not Supported DB dialects. r.name: %s", r.name)
	}

	// Check the query error first: count is only meaningful on success.
	if err != nil {
		return false, xerrors.Errorf("Failed to count tables. err: %w", err)
	}
	return count > 0, nil
}

// GetFetchMeta loads one FetchMeta record from the database.
// When no record exists, a default FetchMeta is returned that carries the
// current revision, the latest schema version and a LastFetchedAt of
// 1000-01-01 UTC. Any other query error is returned as is.
func (r *RDBDriver) GetFetchMeta() (fetchMeta *models.FetchMeta, err error) {
	err = r.conn.Take(&fetchMeta).Error
	switch {
	case err == nil:
		return fetchMeta, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return &models.FetchMeta{GoCVEDictRevision: config.Revision, SchemaVersion: models.LatestSchemaVersion, LastFetchedAt: time.Date(1000, time.January, 1, 0, 0, 0, 0, time.UTC)}, nil
	default:
		return nil, err
	}
}

// UpsertFetchMeta stamps fetchMeta with the current config.Revision and
// models.LatestSchemaVersion and then persists it with gorm's Save.
// The caller's fetchMeta is modified in place.
func (r *RDBDriver) UpsertFetchMeta(fetchMeta *models.FetchMeta) error {
	fetchMeta.GoCVEDictRevision = config.Revision
	fetchMeta.SchemaVersion = models.LatestSchemaVersion
	return r.conn.Save(fetchMeta).Error
}

// GetMulti looks up every given CVE ID with Get and returns the results keyed
// by CVE ID. The first lookup error aborts the whole call.
func (r *RDBDriver) GetMulti(cveIDs []string) (map[string]models.CveDetail, error) {
	found := make(map[string]models.CveDetail)
	for i := range cveIDs {
		id := cveIDs[i]
		detail, err := r.Get(id)
		if err != nil {
			return nil, err
		}
		found[id] = *detail
	}
	return found, nil
}

// Get collects the NVD and JVN records stored for cveID.
// NVD rows are loaded with their associations, then the EnvCpes of every NVD
// CPE are fetched with one query per CPE, and finally the JVN rows are loaded
// with their associations. The returned detail always carries cveID, even
// when no rows matched.
func (r *RDBDriver) Get(cveID string) (*models.CveDetail, error) {
	detail := models.CveDetail{CveID: cveID}

	nvdQuery := r.conn.Where(&models.Nvd{CveID: cveID})
	for _, assoc := range []string{"Descriptions", "Cvss2", "Cvss3", "Cwes", "Cpes", "References", "Certs"} {
		nvdQuery = nvdQuery.Preload(assoc)
	}
	if err := nvdQuery.Find(&detail.Nvds).Error; err != nil {
		return nil, xerrors.Errorf("Failed to fill Nvd. Nvd{CveID: %s} err: %w", cveID, err)
	}

	for i := range detail.Nvds {
		for j := range detail.Nvds[i].Cpes {
			// Work through a pointer so EnvCpes is filled in place.
			cpe := &detail.Nvds[i].Cpes[j]
			cpeID := uint(cpe.ID)
			envErr := r.conn.
				Where(&models.NvdEnvCpe{NvdCpeID: cpeID}).
				Find(&cpe.EnvCpes).Error
			if envErr != nil {
				return nil, xerrors.Errorf("Failed to fill Nvd EnvCpes. Nvd{CveID: %s} Cpes:{NvdCpeID: %d} err: %w", cveID, cpeID, envErr)
			}
		}
	}

	jvnQuery := r.conn.Where(&models.Jvn{CveID: cveID})
	for _, assoc := range []string{"Cvss2", "Cvss3", "Cpes", "References", "Certs"} {
		jvnQuery = jvnQuery.Preload(assoc)
	}
	if err := jvnQuery.Find(&detail.Jvns).Error; err != nil {
		return nil, xerrors.Errorf("Failed to fill Jvn. Jvn{CveID: %s} err: %w", cveID, err)
	}

	return &detail, nil
}

// getCveIDsByPartVendorProduct unbinds the CPE URI and returns the CVE IDs of
// all NVD and JVN entries having a CPE with the same part, vendor and product.
// NVD IDs come first, followed by JVN IDs; duplicates are not removed.
func (r *RDBDriver) getCveIDsByPartVendorProduct(uri string) ([]string, error) {
	wfn, err := naming.UnbindURI(uri)
	if err != nil {
		return nil, err
	}

	var (
		part    = wfn.GetString(common.AttributePart)
		vendor  = wfn.GetString(common.AttributeVendor)
		product = wfn.GetString(common.AttributeProduct)
	)

	nvdRows := []models.Nvd{}
	nvdErr := r.conn.
		Select("nvds.cve_id").
		Joins("JOIN nvd_cpes ON nvd_cpes.nvd_id = nvds.id").
		Where("nvd_cpes.part = ? AND nvd_cpes.vendor = ? AND nvd_cpes.product = ?", part, vendor, product).
		Find(&nvdRows).Error
	if nvdErr != nil {
		return nil, nvdErr
	}

	jvnRows := []models.Jvn{}
	jvnErr := r.conn.
		Select("jvns.cve_id").
		Joins("JOIN jvn_cpes ON jvn_cpes.jvn_id = jvns.id").
		Where("jvn_cpes.part = ? AND jvn_cpes.vendor = ? AND jvn_cpes.product = ?", part, vendor, product).
		Find(&jvnRows).Error
	if jvnErr != nil {
		return nil, jvnErr
	}

	ids := make([]string, 0, len(nvdRows)+len(jvnRows))
	for i := range nvdRows {
		ids = append(ids, nvdRows[i].CveID)
	}
	for i := range jvnRows {
		ids = append(ids, jvnRows[i].CveID)
	}

	return ids, nil
}

// GetCveIDsByCpeURI selects CVE IDs by pseudo-CPE.
// Candidates sharing the URI's part, vendor and product are de-duplicated,
// loaded with Get, filtered with filterCveDetailByCpeURI and matched with
// matchCpe. A CVE matching on the NVD side goes to nvdCveIDs; one matching
// only on the JVN side goes to jvnCveIDs. A matchCpe failure is logged as a
// warning and that CVE is skipped. Result order is unspecified because the
// candidates are iterated through a map.
func (r *RDBDriver) GetCveIDsByCpeURI(uri string) (nvdCveIDs []string, jvnCveIDs []string, err error) {
	candidates, err := r.getCveIDsByPartVendorProduct(uri)
	if err != nil {
		return nil, nil, err
	}

	seen := make(map[string]struct{}, len(candidates))
	for _, id := range candidates {
		seen[id] = struct{}{}
	}

	nvdCveIDs, jvnCveIDs = []string{}, []string{}
	for id := range seen {
		detail, err := r.Get(id)
		if err != nil {
			return nil, nil, err
		}
		if err := filterCveDetailByCpeURI(uri, detail); err != nil {
			return nil, nil, err
		}

		nvdMatch, jvnMatch, err := matchCpe(uri, detail)
		if err != nil {
			logger.Warnf("Failed to compare the version:%s %s %#v",
				err, uri, detail)
			continue
		}

		// An NVD match takes precedence over a JVN match.
		switch {
		case nvdMatch:
			nvdCveIDs = append(nvdCveIDs, detail.CveID)
		case jvnMatch:
			jvnCveIDs = append(jvnCveIDs, detail.CveID)
		}
	}

	return nvdCveIDs, jvnCveIDs, nil
}

// GetByCpeURI returns the CVE details related to the given CPE URI.
// Candidates sharing the URI's part, vendor and product are de-duplicated,
// loaded with Get and filtered with filterCveDetailByCpeURI; only details
// that still hold at least one NVD or JVN entry are kept. Result order is
// unspecified because the candidates are iterated through a map.
func (r *RDBDriver) GetByCpeURI(uri string) ([]models.CveDetail, error) {
	candidates, err := r.getCveIDsByPartVendorProduct(uri)
	if err != nil {
		return nil, err
	}

	seen := make(map[string]struct{}, len(candidates))
	for _, id := range candidates {
		seen[id] = struct{}{}
	}

	matched := []models.CveDetail{}
	for id := range seen {
		detail, getErr := r.Get(id)
		if getErr != nil {
			return nil, getErr
		}
		if filterErr := filterCveDetailByCpeURI(uri, detail); filterErr != nil {
			return nil, filterErr
		}
		if len(detail.Nvds) == 0 && len(detail.Jvns) == 0 {
			continue
		}
		matched = append(matched, *detail)
	}
	return matched, nil
}

// CountJvn returns the number of rows stored for the Jvn model.
func (r *RDBDriver) CountJvn() (int, error) {
	var total int64
	err := r.conn.Model(&models.Jvn{}).Count(&total).Error
	if err != nil {
		return 0, err
	}
	return int(total), nil
}

// InsertJvn replaces the JVN data in the database within one transaction.
//
// All existing JVN rows are deleted, the "recent" and "modified" feeds are
// fetched, and then each requested year is fetched and inserted in turn.
// When years is empty, every year from 1998 up to the current year is used.
// The insert batch size is read from the "batch-size" viper setting and must
// be at least 1.
//
// Any failure rolls the transaction back and is returned as an error. A panic
// raised while the transaction is open is also rolled back and reported as an
// error instead of being swallowed.
func (r *RDBDriver) InsertJvn(years []string) (err error) {
	// Validate the configuration before opening the transaction, so that an
	// invalid batch size cannot leave a transaction open.
	batchSize := viper.GetInt("batch-size")
	if batchSize < 1 {
		return fmt.Errorf("Failed to set batch-size. err: batch-size option is not set properly")
	}

	tx := r.conn.Begin()
	if beginErr := tx.Error; beginErr != nil {
		return beginErr
	}
	defer func() {
		if re := recover(); re != nil {
			tx.Rollback()
			// Surface the panic; otherwise the caller would see a nil error.
			err = xerrors.Errorf("Failed to insert JVN. panic: %v", re)
		}
	}()

	logger.Infof("Deleting JVN tables...")
	if err := deleteJvn(tx); err != nil {
		tx.Rollback()
		return xerrors.Errorf("Failed to deleteJvn. err: %w", err)
	}

	// {"year", "recent" or "modified": { "JVNID#CVE-ID":Jvn{} } }
	uniqCves := map[string]map[string]models.Jvn{}

	logger.Infof("Fetching CVE information from JVN(recent, modified).")
	if err := jvn.FetchConvert(uniqCves, []string{"recent", "modified"}); err != nil {
		tx.Rollback()
		return xerrors.Errorf("Failed to FetchConvert. err: %w", err)
	}

	if len(years) == 0 {
		for y := 1998; y <= time.Now().Year(); y++ {
			years = append(years, fmt.Sprintf("%d", y))
		}
	}
	for _, year := range years {
		logger.Infof("Fetching CVE information from JVN(%s).", year)
		if err := jvn.FetchConvert(uniqCves, []string{year}); err != nil {
			tx.Rollback()
			return xerrors.Errorf("Failed to FetchConvert. err: %w", err)
		}

		cves := make([]models.Jvn, 0, len(uniqCves[year]))
		for _, cve := range uniqCves[year] {
			cves = append(cves, cve)
		}
		// Drop the year's entries once copied out of the map.
		delete(uniqCves, year)

		logger.Infof("Inserting fetched CVEs(%s)...", year)
		if err := insertJvn(tx, cves, batchSize); err != nil {
			tx.Rollback()
			return xerrors.Errorf("Failed to insertJvn. err: %w", err)
		}
		logger.Infof("Refreshed %d CVEs.", len(cves))
	}

	if err := tx.Commit().Error; err != nil {
		return xerrors.Errorf("Failed to Commit Transaction. err: %w", err)
	}
	return nil
}

// deleteJvn removes every row of each JVN model table using tx.
// AllowGlobalUpdate is enabled per statement because the deletes carry no
// conditions. The first failure stops the loop and is returned wrapped.
func deleteJvn(tx *gorm.DB) error {
	targets := []interface{}{
		models.Jvn{},
		models.JvnCvss2{},
		models.JvnCvss3{},
		models.JvnCpe{},
		models.JvnReference{},
		models.JvnCert{},
	}
	for _, model := range targets {
		result := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(model)
		if result.Error != nil {
			return xerrors.Errorf("Failed to delete old records. err: %w", result.Error)
		}
	}
	return nil
}

// insertJvn inserts the given JVN entries using tx while showing a progress
// bar.
//
// Each entry is created without its Cpes association; the CPEs are then
// linked to the new row through JvnID and inserted in chunks of at most
// batchSize, as produced by chunkSlice. The first failure is returned
// wrapped. The progress bar is finished on every return path, including
// errors.
func insertJvn(tx *gorm.DB, cves []models.Jvn, batchSize int) error {
	bar := pb.StartNew(len(cves))
	// Finish the bar on error returns too, not only after a complete run.
	defer bar.Finish()

	for _, cve := range cves {
		if err := tx.Omit("Cpes").Create(&cve).Error; err != nil {
			return xerrors.Errorf("Failed to insert. err: %w", err)
		}

		// The parent ID is only known after Create has run.
		for i := range cve.Cpes {
			cve.Cpes[i].JvnID = uint(cve.ID)
		}

		for idx := range chunkSlice(len(cve.Cpes), batchSize) {
			if err := tx.Create(cve.Cpes[idx.From:idx.To]).Error; err != nil {
				return xerrors.Errorf("Failed to insert. err: %w", err)
			}
		}
		bar.Increment()
	}

	return nil
}

// CountNvd returns the number of rows stored for the Nvd model.
func (r *RDBDriver) CountNvd() (int, error) {
	var total int64
	err := r.conn.Model(&models.Nvd{}).Count(&total).Error
	if err != nil {
		return 0, err
	}
	return int(total), nil
}

// InsertNvd replaces the NVD data in the database within one transaction.
//
// All existing NVD rows are deleted, the "recent" and "modified" feeds are
// fetched, and then each requested year is fetched and inserted in turn.
// When years is empty, every year from 2002 up to the current year is used.
// The insert batch size is read from the "batch-size" viper setting and must
// be at least 1.
//
// Any failure rolls the transaction back and is returned as an error. A panic
// raised while the transaction is open is also rolled back and reported as an
// error instead of being swallowed.
func (r *RDBDriver) InsertNvd(years []string) (err error) {
	// Validate the configuration before opening the transaction, so that an
	// invalid batch size cannot leave a transaction open.
	batchSize := viper.GetInt("batch-size")
	if batchSize < 1 {
		return fmt.Errorf("Failed to set batch-size. err: batch-size option is not set properly")
	}

	tx := r.conn.Begin()
	if beginErr := tx.Error; beginErr != nil {
		return beginErr
	}
	defer func() {
		if re := recover(); re != nil {
			tx.Rollback()
			// Surface the panic; otherwise the caller would see a nil error.
			err = xerrors.Errorf("Failed to insert NVD. panic: %v", re)
		}
	}()

	logger.Infof("Deleting NVD tables...")
	if err := deleteNvd(tx); err != nil {
		tx.Rollback()
		return xerrors.Errorf("Failed to deleteNvd. err: %w", err)
	}

	// {"year", "recent" or "modified": { "CVE-ID": Nvd{} } }
	uniqCves := map[string]map[string]models.Nvd{}

	logger.Infof("Fetching CVE information from NVD(recent, modified).")
	if err := nvd.FetchConvert(uniqCves, []string{"recent", "modified"}); err != nil {
		tx.Rollback()
		return xerrors.Errorf("Failed to FetchConvert. err: %w", err)
	}

	if len(years) == 0 {
		for y := 2002; y <= time.Now().Year(); y++ {
			years = append(years, fmt.Sprintf("%d", y))
		}
	}
	for _, year := range years {
		logger.Infof("Fetching CVE information from NVD(%s).", year)
		if err := nvd.FetchConvert(uniqCves, []string{year}); err != nil {
			tx.Rollback()
			return xerrors.Errorf("Failed to FetchConvert. err: %w", err)
		}

		cves := make([]models.Nvd, 0, len(uniqCves[year]))
		for _, cve := range uniqCves[year] {
			cves = append(cves, cve)
		}
		// Drop the year's entries once copied out of the map.
		delete(uniqCves, year)

		logger.Infof("Inserting fetched CVEs(%s)...", year)
		if err := insertNvd(tx, cves, batchSize); err != nil {
			tx.Rollback()
			return xerrors.Errorf("Failed to insertNvd. err: %w", err)
		}
		logger.Infof("Refreshed %d CVEs.", len(cves))
	}

	if err := tx.Commit().Error; err != nil {
		return xerrors.Errorf("Failed to Commit Transaction. err: %w", err)
	}
	return nil
}

// deleteNvd removes every row of each NVD model table using tx.
// AllowGlobalUpdate is enabled per statement because the deletes carry no
// conditions. The first failure stops the loop and is returned wrapped.
func deleteNvd(tx *gorm.DB) error {
	targets := []interface{}{
		models.Nvd{},
		models.NvdDescription{},
		models.NvdCvss2Extra{},
		models.NvdCvss3{},
		models.NvdCwe{},
		models.NvdCpe{},
		models.NvdEnvCpe{},
		models.NvdReference{},
		models.NvdCert{},
	}
	for _, model := range targets {
		result := tx.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(model)
		if result.Error != nil {
			return xerrors.Errorf("Failed to delete old records. err: %w", result.Error)
		}
	}
	return nil
}

// insertNvd inserts the given NVD entries using tx while showing a progress
// bar.
//
// Each entry is created without its Cpes association; the CPEs are then
// linked to the new row through NvdID and inserted in chunks of at most
// batchSize, as produced by chunkSlice. The first failure is returned
// wrapped. The progress bar is finished on every return path, including
// errors.
func insertNvd(tx *gorm.DB, cves []models.Nvd, batchSize int) error {
	bar := pb.StartNew(len(cves))
	// Finish the bar on error returns too, not only after a complete run.
	defer bar.Finish()

	for _, cve := range cves {
		if err := tx.Omit("Cpes").Create(&cve).Error; err != nil {
			return xerrors.Errorf("Failed to insert. err: %w", err)
		}

		// The parent ID is only known after Create has run.
		for i := range cve.Cpes {
			cve.Cpes[i].NvdID = uint(cve.ID)
		}

		for idx := range chunkSlice(len(cve.Cpes), batchSize) {
			if err := tx.Create(cve.Cpes[idx.From:idx.To]).Error; err != nil {
				return xerrors.Errorf("Failed to insert. err: %w", err)
			}
		}
		bar.Increment()
	}

	return nil
}
